### Create regression tables for all models

## Report the run date and the package/session details (autoprinted when the
## script is run at top level, e.g. via Rscript) to aid reproducibility.
date()
library("tidyverse")
## NOTE(review): maxLik is presumably attached so that coef()/logLik() dispatch
## on the stored model fits used below — confirm.
library("maxLik")
sessionInfo()


## Restore model fits and bootstrap draws for the three specifications
## (base, joiners, postwar).
## NOTE(review): presumably each .rda file defines the object named like the
## file (fit_base, boot_base, ...), which the code below relies on — confirm.
load("results/fit_base.rda")
load("results/fit_joiners.rda")
load("results/fit_postwar.rda")
load("results/boot_base.rda")
load("results/boot_joiners.rda")
load("results/boot_postwar.rda")


## Extract coefficient estimates and convert cost parameters (gamma) into force-multiplier parameters (beta) by flipping their sign
## Convert cost coefficients into force-multiplier coefficients.
##
## Columns whose names start with "gamma" hold cost parameters; each is
## negated and its prefix renamed to "beta" (e.g. "gamma:x" -> "beta:x").
## All other columns pass through unchanged.
##
## dat:     data frame / tibble of coefficients, one column per term.
## n_gamma: number of cost columns expected (sanity check); defaults to 2.
##
## Returns `dat` with the cost columns converted. Stops with an error if the
## number of "gamma" columns differs from `n_gamma`.
cost_to_mult <- function(dat, n_gamma = 2) {
  ## One case-sensitive match drives selection, renaming and the check, so
  ## all three agree (starts_with() would match case-insensitively by default).
  gamma_cols <- colnames(dat)[str_detect(colnames(dat), "^gamma")]
  if (length(gamma_cols) != n_gamma) {
    stop(
      "expected ", n_gamma, " cost (gamma) columns but found ",
      length(gamma_cols),
      call. = FALSE
    )
  }
  out <- mutate_at(dat, gamma_cols, ~ .x * -1)
  colnames(out) <- str_replace(colnames(out), "^gamma", "beta")
  out
}
## Collect point estimates: one row per element of `fit` (one per imputation),
## taken from coef() of each element's `$fit`, with cost parameters converted
## to multipliers.
make_point <- function(fit) {
  coef_rows <- lapply(fit, function(imp) coef(imp$fit))
  coef_mat <- do.call(rbind, coef_rows)
  cost_to_mult(as_tibble(coef_mat))
}
## Stack every element of `boot` (each coerced to a tibble) into a single
## tibble of bootstrap draws, with cost parameters converted to multipliers.
make_boot <- function(boot) {
  draws <- bind_rows(lapply(boot, as_tibble))
  cost_to_mult(draws)
}

## Point estimates (one row per imputation) and bootstrap draws, per model.
## Base specification
cf_point_base <- make_point(fit_base)
cf_boot_base <- make_boot(boot_base)
## Joiners specification
cf_point_joiners <- make_point(fit_joiners)
cf_boot_joiners <- make_boot(boot_joiners)
## Postwar specification
cf_point_postwar <- make_point(fit_postwar)
cf_boot_postwar <- make_boot(boot_postwar)

## Build the LaTeX regression table for one model specification.
##
## cf_point: tibble of point estimates, one row per imputation and one column
##           per term (as produced by make_point()).
## cf_boot:  tibble of bootstrap draws with the same columns (as produced by
##           make_boot()).
## fit:      list of fitted results, one per imputation; each element's `$fit`
##           must support logLik(), and the first element's `$data_dispute`
##           and `$data_participant` supply the reported sample sizes.
## show_ic:  if TRUE, also report AIC and BIC (averaged over imputations) in
##           the bottom panel. Defaults to FALSE (AIC/BIC omitted).
##
## Returns a character vector of LaTeX lines forming a `tabular` environment.
make_reg_table <- function(cf_point, cf_boot, fit, show_ic = FALSE) {
  ## Calculate average coefficients across main model imputations
  dat_reg_table_orig <- cf_point %>%
    gather(key = "term", value = "value") %>%
    group_by(term) %>%
    summarise(est = mean(value))

  ## Calculate bootstrap standard errors and relevant quantiles for confidence intervals
  dat_boot_se <- cf_boot %>%
    gather(key = "term", value = "value") %>%
    group_by(term) %>%
    summarise(
      se = sd(value),
      q_lwr = quantile(value, 0.025),
      q_upr = quantile(value, 0.975)
    )
  dat_reg_table <- left_join(dat_reg_table_orig, dat_boot_se, by = "term")

  ## Calculate pivot (basic bootstrap) confidence intervals
  dat_reg_table <- dat_reg_table %>%
    mutate(
      ci_lwr = 2 * est - q_upr,
      ci_upr = 2 * est - q_lwr
    )

  ## Add supplemental data for the regression table
  n_disp <- nrow(fit[[1]]$data_dispute)
  n_part <- nrow(fit[[1]]$data_participant)
  ll <- map_dbl(fit, ~ logLik(.x$fit))
  n_params <- nrow(dat_reg_table)
  dat_reg_table <- dat_reg_table %>%
    add_row(term = "n_disp", est = n_disp) %>%
    add_row(term = "n_part", est = n_part) %>%
    add_row(term = "loglik", est = mean(ll))

  ## Terms shown in the un-headed bottom panel; counts print without decimals.
  count_terms <- c("n_disp", "n_part")
  bottom_terms <- c(count_terms, "loglik")
  if (show_ic) {
    ## Per-imputation information criteria, then averaged; BIC uses the
    ## number of disputes as the sample size.
    aic <- 2 * n_params - 2 * ll
    bic <- log(n_disp) * n_params - 2 * ll
    dat_reg_table <- dat_reg_table %>%
      add_row(term = "aic", est = mean(aic)) %>%
      add_row(term = "bic", est = mean(bic))
    bottom_terms <- c(bottom_terms, "aic", "bic")
  }

  ## Map raw term names to display labels. Leading spaces set the row order
  ## within a category (a space sorts before letters in the C-locale sort used
  ## below); LaTeX ignores them in the output.
  pretty_term_names <- function(term) {
    recode(term,
      "log(gdp_pwt)" = "GDP",
      "log1p(pec)" = "Energy Consumption",
      "log1p(irst)" = "Iron and Steel Production",
      "log1p(tpop)" = "Total Population",
      "log1p(upop)" = "Urban Population",
      "nuclear" = "Nuclear Weapons",
      "log1p(distance)" = "Distance to Dispute",
      "log1p(pct_imports)" = "Import Percentage",
      "polity2" = "Democracy",
      "polity_a" = "Democracy",
      "majpow_a" = "Major Power",
      "log1p(py_alt)" = "Peace Years",
      "s_cinc" = "Interest Similarity",
      "contig" = "Contiguity",
      "(Intercept)" = " Intercept",
      "(Intercept_scl)" = " Intercept",
      "log(n_states_a)" = "Coalition Size",
      "n_disp" = "     Disputes",
      "n_part" = "    Participants",
      "loglik" = "   Log-likelihood",
      "aic" = "  AIC",
      "bic" = " BIC"
    )
  }

  ## Map category keys to the LaTeX panel headings.
  pretty_category_names <- function(category) {
    recode(category,
      "economic" = "Force Multiplier $m_i/c_i$: Economic",
      "demographic" = "Force Multiplier $m_i/c_i$: Demographic",
      "political" = "Force Multiplier $m_i/c_i$: Political",
      "geopolitical" = "Force Multiplier $m_i/c_i$: Geopolitical",
      "loc" = "Audience Cost Mean $\\mu_K$",
      "scl" = "Audience Cost Scale $\\sigma_K$"
    )
  }

  ## Render one panel: a rule, an italic heading (none for the "bottom"
  ## panel), then one LaTeX row per term.
  rows_from_df <- function(category, dat) {
    ## Test the raw key, not the prettified label, for the bottom panel.
    category <- as.character(category)
    if (category != "bottom") {
      label <- pretty_category_names(category)
      header <- str_glue("\\multicolumn{{4}}{{l}}{{\\textit{{{label}}}}} \\\\")
    } else {
      header <- NULL
    }
    txt <- with(dat, str_glue("{term} & {est} & {se} & {ci} \\\\"))
    c("\\midrule", header, txt)
  }

  ## Create table data with nice names and categories
  dat_reg_table <- dat_reg_table %>%
    mutate(
      term = if_else(term == "scl:(Intercept)", "scl:(Intercept_scl)", term),
      term = str_replace(term, "^(beta|gamma|loc|scl):", "")
    ) %>%
    mutate_at(c("est", "se", "ci_lwr", "ci_upr"), ~ sprintf("%.2f", .x)) %>%
    mutate_at(c("est", "ci_lwr", "ci_upr"), ~ str_replace(.x, "^-", "$-$")) %>%
    mutate(
      ci = str_c("[", ci_lwr, ", ", ci_upr, "]"),
      ci = if_else(ci_lwr == "NA", "", ci)
    ) %>%
    select(term, est, se, ci) %>%
    mutate(
      category = fct_collapse(
        term,
        demographic = c("log1p(tpop)", "log1p(upop)"),
        economic = c("log(gdp_pwt)", "log1p(irst)", "log1p(pec)", "log1p(pct_imports)"),
        political = c("polity2"),
        geopolitical = c("nuclear", "log1p(distance)"),
        loc = c("polity_a", "majpow_a", "contig", "s_cinc", "log1p(py_alt)", "(Intercept)"),
        scl = c("(Intercept_scl)", "log(n_states_a)"),
        bottom = bottom_terms
      ),
      ## Named panels first in this order; the bottom panel always last.
      category = fct_relevel(category, "demographic", "economic", "political", "geopolitical", "loc", "scl"),
      category = fct_relevel(category, "bottom", after = Inf),
      est = if_else(term %in% count_terms, str_replace(est, "\\.[0-9]*$", ""), est),
      term = pretty_term_names(term)
    ) %>%
    mutate_if(is.character, ~ if_else(.x == "NA", "", .x)) %>%
    ## Radix ordering always compares strings in the C locale, so the
    ## leading-space ordering of labels holds regardless of system locale
    ## or dplyr version (arrange() used the system locale before dplyr 1.1).
    slice(order(category, term, method = "radix")) %>%
    group_by(category) %>%
    nest()

  reg_table_body <-
    with(dat_reg_table, map2(category, data, rows_from_df)) %>%
    unlist(use.names = FALSE)

  reg_table_full <-
    c(
      "\\begin{tabular}{lrrc}",
      "\\toprule",
      "Term & Estimate & Std.\\ Err.\\ & Conf.\\ Int.\\ \\\\",
      reg_table_body,
      "\\bottomrule",
      "\\end{tabular}"
    )

  reg_table_full
}

## Render the LaTeX table for each specification.
reg_table_base <- make_reg_table(cf_point_base, cf_boot_base, fit_base)
reg_table_joiners <- make_reg_table(cf_point_joiners, cf_boot_joiners, fit_joiners)
reg_table_postwar <- make_reg_table(cf_point_postwar, cf_boot_postwar, fit_postwar)

## Write each table to tables/, creating the directory on first run.
if (!dir.exists("tables")) {
  dir.create("tables")
}
writeLines(con = "tables/table_3.tex", text = reg_table_base)
writeLines(con = "tables/table_A4.tex", text = reg_table_joiners)
writeLines(con = "tables/table_A9.tex", text = reg_table_postwar)


## Report the finishing time.
date()
